{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LMDB creation\n",
    "This notebook creates databases (in LMDB format) from collections of images.\n",
    "There are two functions: one for creating an LMDB from RGB-like images\n",
    "and one for creating an LMDB from a dense ground truth (i.e. 2D labels)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Imports and stuff\n",
    "import numpy as np\n",
    "import lmdb\n",
    "import glob\n",
    "from skimage import img_as_float, io\n",
    "from random import shuffle\n",
    "from tqdm import tqdm\n",
    "import sys\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have to import several variables from the config file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from config import BGR, DATASET_DIR, CAFFE_ROOT,\\\n",
    "                   data_lmdbs, test_lmdbs, label_lmdbs, test_label_lmdbs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Make pycaffe importable from the Caffe checkout named in the config.\n",
    "# os.path.join copes with CAFFE_ROOT given with or without a trailing slash.\n",
    "sys.path.insert(0, os.path.join(CAFFE_ROOT, 'python'))\n",
    "import caffe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def list_images(folder, pattern='*', ext='png'):\n",
    "    \"\"\"List the images in a specified folder by pattern and extension\n",
    "\n",
    "    Args:\n",
    "        folder (str): folder containing the images to list\n",
    "        pattern (str, optional): a bash-like pattern of the files to select\n",
    "                                 defaults to * (everything)\n",
    "        ext(str, optional): the image extension (defaults to png)\n",
    "\n",
    "    Returns:\n",
    "        str list: list of (filenames) images matching the pattern in the folder\n",
    "    \"\"\"\n",
    "    filenames = sorted(glob.glob(folder + pattern + '.' + ext))\n",
    "    return filenames"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define two functions for LMDB creation: one for the images and one for the labels. The labels are 2D matrices (HxW), whereas the images have a 3D shape (CxHxW)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Adapted from https://github.com/BVLC/caffe/issues/1698\n",
    "def create_image_lmdb(target, samples, bgr=BGR, normalize=False):\n",
    "    \"\"\"Create an image LMDB\n",
    "\n",
    "    Each image is read from disk, optionally converted to float and\n",
    "    channel-reversed, reordered to Channel x Height x Width and stored as a\n",
    "    serialized Caffe Datum under a zero-padded 10-digit key equal to its\n",
    "    position in `samples`.\n",
    "\n",
    "    Args:\n",
    "        target (str): path of the LMDB to be created\n",
    "        samples (str list): filenames of the images to be included in the\n",
    "                            LMDB, in the order they should be stored\n",
    "        bgr (bool): True if we want to reverse the channel order (RGB->BGR)\n",
    "        normalize (bool): True if we want to normalize data in [0,1]\n",
    "\n",
    "    Raises:\n",
    "        Exception: if a directory already exists at `target`\n",
    "    \"\"\"\n",
    "\n",
    "    # Refuse to touch an existing database\n",
    "    if os.path.isdir(target):\n",
    "        raise Exception(\"LMDB already exists in {}, aborted.\".format(target))\n",
    "    db = lmdb.open(target, map_size=int(1e12))\n",
    "    try:\n",
    "        # All samples are written inside a single write transaction\n",
    "        with db.begin(write=True) as txn:\n",
    "            for idx, filename in tqdm(enumerate(samples), total=len(samples)):\n",
    "                # load image:\n",
    "                image = io.imread(filename)\n",
    "                if normalize:\n",
    "                    # - in [0,1.]float range\n",
    "                    image = img_as_float(image)\n",
    "                if bgr:\n",
    "                    # - in BGR (reverse from RGB)\n",
    "                    image = image[:, :, ::-1]\n",
    "                # - in Channel x Height x Width order (switch from H x W x C);\n",
    "                #   this requires a 3D array, a 2D (grayscale) image fails here\n",
    "                image = image.transpose((2, 0, 1))\n",
    "                datum = caffe.io.array_to_datum(image)\n",
    "                # Encode the key to bytes: a no-op on Python 2, needed on Python 3\n",
    "                key = '{:0>10d}'.format(idx).encode('ascii')\n",
    "                # Write the data into the db\n",
    "                txn.put(key, datum.SerializeToString())\n",
    "    finally:\n",
    "        # Release the environment even if reading or writing a sample failed\n",
    "        db.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def create_label_lmdb(target, labels):\n",
    "    \"\"\"Create a label LMDB\n",
    "\n",
    "    Each label image is read from disk, given a singleton channel axis\n",
    "    (H x W -> 1 x H x W) and stored as a serialized Caffe Datum under a\n",
    "    zero-padded 10-digit key equal to its position in `labels`.\n",
    "\n",
    "    Args:\n",
    "        target (str): path of the LMDB to be created\n",
    "        labels (str list): filenames of the 2D-labels to be included in the\n",
    "                           LMDB, in the order they should be stored\n",
    "\n",
    "    Raises:\n",
    "        Exception: if a directory already exists at `target`\n",
    "    \"\"\"\n",
    "    # Refuse to touch an existing database\n",
    "    if os.path.isdir(target):\n",
    "        raise Exception(\"LMDB already exists in {}, aborted.\".format(target))\n",
    "    db = lmdb.open(target, map_size=int(1e12))\n",
    "    try:\n",
    "        # All labels are written inside a single write transaction\n",
    "        with db.begin(write=True) as txn:\n",
    "            for idx, filename in tqdm(enumerate(labels), total=len(labels)):\n",
    "                label = io.imread(filename)\n",
    "                # Add a singleton third dimension\n",
    "                label = label.reshape(label.shape + (1,))\n",
    "                # Switch to Channel x Height x Width order\n",
    "                label = label.transpose((2, 0, 1))\n",
    "                datum = caffe.io.array_to_datum(label)\n",
    "                # Encode the key to bytes: a no-op on Python 2, needed on Python 3\n",
    "                key = '{:0>10d}'.format(idx).encode('ascii')\n",
    "                # Write the data into the db\n",
    "                txn.put(key, datum.SerializeToString())\n",
    "    finally:\n",
    "        # Release the environment even if reading or writing a label failed\n",
    "        db.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Get the RNG state to always shuffle the same way\n",
    "RNG_STATE = np.random.get_state()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We run the code for our data lmdbs and labels lmdbs (train and test each time)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Create each LMDB\n",
    "def _build_lmdbs(lmdb_pairs, create_fn, shuffle=False, **create_kwargs):\n",
    "    \"\"\"Build one LMDB per (source folder, target LMDB path) pair\n",
    "\n",
    "    Args:\n",
    "        lmdb_pairs (list): (source_folder, target_folder) pairs\n",
    "        create_fn (callable): create_image_lmdb or create_label_lmdb\n",
    "        shuffle (bool): True to shuffle the file list; RNG_STATE is restored\n",
    "                        first so every shuffled list of a given length gets\n",
    "                        the same permutation (keeps images/labels aligned)\n",
    "        **create_kwargs: extra keyword arguments forwarded to create_fn\n",
    "    \"\"\"\n",
    "    for source_folder, target_folder in tqdm(lmdb_pairs):\n",
    "        print(\"=== Creating LMDB for {} ===\".format(source_folder))\n",
    "        sys.stdout.flush()\n",
    "        samples = list_images(source_folder)\n",
    "        if shuffle:\n",
    "            np.random.set_state(RNG_STATE)\n",
    "            np.random.shuffle(samples)\n",
    "        create_fn(target_folder, samples, **create_kwargs)\n",
    "\n",
    "# Training sets are shuffled; image and label lists share the same permutation.\n",
    "# The channel order follows the config (BGR) for both train and test images,\n",
    "# so the two sets can never end up with different channel orders.\n",
    "_build_lmdbs(data_lmdbs, create_image_lmdb, shuffle=True, bgr=BGR)\n",
    "_build_lmdbs(label_lmdbs, create_label_lmdb, shuffle=True)\n",
    "\n",
    "# Test sets keep the sorted file order\n",
    "_build_lmdbs(test_lmdbs, create_image_lmdb, bgr=BGR)\n",
    "_build_lmdbs(test_label_lmdbs, create_label_lmdb)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
